package com.sequenceiq.cloudbreak.concurrent;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.concurrent.locks.Lock;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import com.google.common.util.concurrent.Striped;
import com.sequenceiq.cloudbreak.cloud.event.Payload;
import com.sequenceiq.cloudbreak.cloud.scheduler.CancellationException;
@Component
@Aspect
public class ConcurrentMethodExecutionAspect {
private static final Logger LOGGER = LoggerFactory.getLogger(ConcurrentMethodExecutionAspect.class);
private static final int STRIPES = 10;
private Striped<Lock> locks = Striped.lazyWeakLock(STRIPES);
@Pointcut("execution(@com.sequenceiq.cloudbreak.concurrent.GuardedMethod * *(..))")
public void guardedMethod() {
}
@Pointcut("execution(@com.sequenceiq.cloudbreak.concurrent.LockedMethod * *(..))")
public void lockedMethod() {
}
@Pointcut("args(com.sequenceiq.cloudbreak.cloud.event.Payload)")
public void methodWithPayloadArgument() {
}
@Pointcut("guardedMethod() && methodWithPayloadArgument()")
public void guardedMethodWithPayloadArg() {
}
@Pointcut("lockedMethod() && methodWithPayloadArgument()")
public void lockedMethodWithPayloadArg() {
}
@Around("com.sequenceiq.cloudbreak.concurrent.ConcurrentMethodExecutionAspect.lockedMethodWithPayloadArg()")
public Object executeLockedMethod(ProceedingJoinPoint joinPoint) throws Throwable {
Long stackId = getStackId(joinPoint);
String lockPrefix = getLockedMethodLockPrefix(joinPoint);
String lockKey = createLockKey(lockPrefix, stackId);
Lock lock = locks.get(lockKey);
if (!lock.tryLock()) {
logWaitingOperation(lockPrefix, stackId);
lock.lock();
logContinueOperation(lockPrefix, stackId);
}
try {
return joinPoint.proceed();
} finally {
lock.unlock();
}
}
@Around("com.sequenceiq.cloudbreak.concurrent.ConcurrentMethodExecutionAspect.guardedMethodWithPayloadArg()")
public Object executeGuardedMethod(ProceedingJoinPoint joinPoint) throws Throwable {
Long stackId = getStackId(joinPoint);
String lockPrefix = getGuardedMethodLockPrefix(joinPoint);
String lockKey = createLockKey(lockPrefix, stackId);
Lock lock = locks.get(lockKey);
if (lock.tryLock()) {
try {
return joinPoint.proceed();
} finally {
lock.unlock();
}
} else {
return skipMethodExecution(lockPrefix, stackId);
}
}
private String createLockKey(String lockPrefix, Long stackId) {
return stackId == null ? lockPrefix : lockPrefix + String.valueOf(stackId);
}
private String getGuardedMethodLockPrefix(JoinPoint joinPoint) {
try {
return getAnnotation(GuardedMethod.class, joinPoint).lockPrefix();
} catch (Exception ex) {
return "";
}
}
private String getLockedMethodLockPrefix(JoinPoint joinPoint) {
try {
return getAnnotation(LockedMethod.class, joinPoint).lockPrefix();
} catch (Exception ex) {
return "";
}
}
private <T extends Annotation> T getAnnotation(Class<T> clazz, JoinPoint joinPoint) throws NoSuchMethodException {
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
Method method = joinPoint.getTarget().getClass().getDeclaredMethod(joinPoint.getSignature().getName(),
methodSignature.getMethod().getParameterTypes());
return method.getAnnotation(clazz);
}
private Long getStackId(JoinPoint joinPoint) {
Payload payload = getPayload(joinPoint);
return payload == null ? null : payload.getStackId();
}
private Payload getPayload(JoinPoint joinPoint) {
Payload payload = null;
for (Object arg : joinPoint.getArgs()) {
if (arg instanceof Payload) {
payload = (Payload) arg;
}
}
return payload;
}
private Object skipMethodExecution(String lockPrefix, Long stackId) {
String message;
if (stackId != null) {
message = String.format("%s operation will be skipped on stack %d, because it is running on a different thread.", lockPrefix, stackId);
} else {
message = String.format("%s operation will be skipped, because it is running on a different thread.", lockPrefix);
}
LOGGER.info(message);
throw new CancellationException(message);
}
private void logWaitingOperation(String lockPrefix, Long stackId) {
if (stackId != null) {
LOGGER.info("Waiting for other {} operation on stack {} to be finished.", lockPrefix, stackId);
} else {
LOGGER.info("Waiting for other {} operation to be finished.", lockPrefix);
}
}
private void logContinueOperation(String lockPrefix, Long stackId) {
if (stackId != null) {
LOGGER.info("Continue {} operation on stack {}.", lockPrefix, stackId);
} else {
LOGGER.info("Continue {} operation.", lockPrefix);
}
}
}